package cn.lagou.homework

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object IPDemo {
  def main(args: Array[String]): Unit = {
    // 1. Create the SparkContext (local mode, app named after this object).
    val conf = new SparkConf().setAppName(this.getClass.getCanonicalName.init).setMaster("local[*]")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")

    // Access-log lines are pipe-separated; field 1 is the client IP, converted to Long.
    val accessIps: RDD[Long] = sc.textFile("data/http.log")
      .map(line => ipConvert(line.split("\\|")(1)))

    // IP range table: fields 2/3 are the numeric range bounds, field 6 the city name.
    val ipRanges: Array[(Long, Long, String)] = sc.textFile("data/ip.dat")
      .map { line =>
        val cols = line.split("\\|")
        (cols(2).toLong, cols(3).toLong, cols(6))
      }
      .collect()

    // Broadcast the ranges sorted by lower bound so executors can binary-search them.
    val rangesBC: Broadcast[Array[(Long, Long, String)]] =
      sc.broadcast(ipRanges.sortBy(_._1))

    // Resolve each request IP to a city on the executors, then count hits per city.
    val cityCounts: Array[(String, Int)] = accessIps
      .mapPartitions { ips =>
        val sortedRanges = rangesBC.value
        ips.map(ip => (getCityName(ip, sortedRanges), 1))
      }
      .reduceByKey(_ + _)
      .collect()

    cityCounts.sortBy(_._2).foreach(println)

    // Shut down the SparkContext.
    sc.stop()
  }

  /** Convert a dotted-quad IP string (e.g. "192.168.10.2") to its numeric value.
    *
    * Uses the standard big-endian ip2long encoding:
    *   a.b.c.d  =>  a*2^24 + b*2^16 + c*2^8 + d
    * which is the representation precomputed Long range files (such as ip.dat)
    * conventionally use.
    *
    * Fix over the original: it multiplied by powers of 255 (octets span 0..255,
    * so the base must be 256 for the mapping to be injective) and treated the
    * FIRST octet as least significant — the reverse of the standard encoding —
    * so converted IPs could never line up with standard range data.
    */
  def ipConvert(ip: String): Long =
    ip.split("\\.").foldLeft(0L)((acc, octet) => (acc << 8) | octet.toLong)

  /** Look up the city whose [start, end] IP range contains `ip`.
    *
    * @param ip  numeric IP to resolve
    * @param ips range table sorted ascending by the range start (as prepared by
    *            the broadcast in `main`); each entry is (rangeStart, rangeEnd, city)
    * @return the matching city name, or "Unknown" when no range contains `ip`
    *         (fixes the original's misspelled "Unknow" label)
    */
  def getCityName(ip: Long, ips: Array[(Long, Long, String)]): String = {
    @scala.annotation.tailrec
    def search(lo: Int, hi: Int): String =
      if (lo > hi) "Unknown" // exhausted the table without a hit
      else {
        val mid = lo + (hi - lo) / 2 // overflow-safe midpoint
        val (rangeStart, rangeEnd, city) = ips(mid)
        if (ip < rangeStart) search(lo, mid - 1)
        else if (ip > rangeEnd) search(mid + 1, hi)
        else city
      }
    search(0, ips.length - 1)
  }

}
